In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
from PIL import Image
from torchvision import transforms
import torchvision.utils as vutils
import numpy as np
import torch.nn.functional as F
from math import log2
import math
In [ ]:
# Root directory for dataset
dataroot = "C:\.Python Projects\AnimeGAN\data"
# Number of workers for dataloader
workers = 14
# Set True to prevent releasing and reassigning workers between epochs
persistent_workers = True
# Number of batches to prefetch per worker ( workers * prefect_factor = Number of Batches preloaded )
prefetch_factor = 4
# Batch size during training
batch_sizes = [64, 64, 64, 64, 40, 16, 8, 4] # 512: [64, 64, 64, 64, 10, 6, 4, 2], 256: [64, 64, 64, 64, 32, 16, 8, 4]
# Image size to train up to (inclusive). Paper uses 1024
image_size = 128
assert image_size >= 4, f"{image_size} is not greater than or equal to 4!"
assert image_size <= 1024, f"{image_size} is not less than or equal to 1024!"
assert math.ceil(log2(image_size)) == math.floor(log2(image_size)), f"{image_size} is not a power of 2!"
# Image size to start training at
start_train_at = 4
assert image_size >= 4, f"{start_train_at} is not greater than or equal to 4!"
assert start_train_at <= image_size, f"{start_train_at} is not less than or equal to {image_size}"
assert math.ceil(log2(start_train_at)) == math.floor(log2(start_train_at)), f"{start_train_at} is not a power of 2!"
# Device to push to
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
pin_memory = True if device.type == "cuda" else False
# Size of latent vector, Paper uses 512 for z_dim and in_channels
z_dim = 256
in_channels = 256
# Learning rate for optimizers. Paper uses 1e-3
lr = 1e-3
lambda_gp = 10
# Number of steps to reach desired image size
num_steps = int(log2(image_size / 4)) + 1
# Progressive Epochs
# 4x4: Train the model to see 800k images.
# 8x8, 16x16, ..., img_size x img_size: Train model to see 800k images fading in the new layer and 800k images to stabilize.
#
# 4x4 epochs: 800,000 / dataset_size
# onwards... 2 times 4x4 epochs
prog_epochs = [32] + [64] * (num_steps - 1)
# Used to create the layers progressively.
factors = [1, 1, 1, 1, 1/2, 1/4, 1/8, 1/16, 1/32]
# Fixed noise to monitor progression of model
fixed_noise = torch.randn(64, z_dim, 1, 1).to(device)
# Display images after each epoch
display_images = False
# Use pretrained model
use_pretrained = False
# Set true to show every step with tqdm
update_last = True
start_epoch = 0
step = int(log2(start_train_at / 4))
if use_pretrained:
path = "../PATH_TO_CHECKPOINT.pth"
checkpoint = torch.load(path)
batch_sizes = checkpoint["batch_sizes"]
start_train_at = checkpoint["start_training_at"]
fixed_noise = checkpoint["fixed_noise"]
z_dim = checkpoint["z_dim"]
in_channels = checkpoint["in_channels"]
step = int(log2(start_train_at / 4))
start_epoch = checkpoint["epoch"]
In [3]:
class WSConv2d(nn.Module):
def __init__(self, input_channel, out_channel, kernel_size=3, stride=1, padding=1, gain=2):
super(WSConv2d, self).__init__()
self.conv = nn.Conv2d(input_channel, out_channel, kernel_size, stride, padding)
self.scale = (gain / (input_channel * (kernel_size ** 2))) ** 0.5
self.bias = self.conv.bias
self.conv.bias = None
nn.init.normal_(self.conv.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
In [4]:
class PixelNorm(nn.Module):
def __init__(self):
super(PixelNorm, self).__init__()
self.epsilon = 1e-8
def forward(self, x):
return x / torch.sqrt( torch.mean( x ** 2, dim=1, keepdim=True ) + self.epsilon )
In [5]:
class ConvBlock(nn.Module):
def __init__(self, input_channel, out_channel, pixel_norm=True):
super(ConvBlock, self).__init__()
self.conv1 = WSConv2d(input_channel, out_channel)
self.conv2 = WSConv2d(out_channel, out_channel)
self.leaky = nn.LeakyReLU(0.2)
self.pn = PixelNorm()
self.use_pn = pixel_norm
def forward(self, x):
x = self.leaky(self.conv1(x))
x = self.pn(x) if self.use_pn else x
x = self.leaky(self.conv2(x))
x = self.pn(x) if self.use_pn else x
return x
In [6]:
class Generator(nn.Module):
def __init__(self, z_dim, in_channels, img_channels=3):
super(Generator, self).__init__()
self.initial = nn.Sequential(
PixelNorm(),
nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0), # 1x1 -> 4x4
nn.LeakyReLU(0.2),
WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2),
PixelNorm()
)
self.initial_rgb = WSConv2d(in_channels, img_channels, kernel_size=1, stride=1, padding=0)
self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([self.initial_rgb])
for i in range(len(factors) - 1):
conv_in_c = int(in_channels * factors[i])
conv_out_c = int(in_channels * factors[i+1])
self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
self.rgb_layers.append(WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0))
def fade_in(self, alpha, upscaled, generated):
return torch.tanh(alpha * generated + (1 - alpha) * upscaled)
def forward(self, x, alpha, steps):
out = self.initial(x)
if steps == 0:
return self.initial_rgb(out)
for i in range(steps):
upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
out = self.prog_blocks[i](upscaled)
final_upscaled = self.rgb_layers[steps-1](upscaled)
final_out = self.rgb_layers[steps](out)
return self.fade_in(alpha, final_upscaled, final_out)
In [7]:
class Discriminator(nn.Module):
def __init__(self, in_channels, img_channels=3):
super(Discriminator, self).__init__()
self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
self.leaky = nn.LeakyReLU(0.2)
for i in range(len(factors) - 1, 0, -1):
conv_in_c = int(in_channels * factors[i])
conv_out_c = int(in_channels * factors[i-1])
self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c, pixel_norm=False))
self.rgb_layers.append(WSConv2d(img_channels, conv_in_c, kernel_size=1, stride=1, padding=0))
# 4x4 img res
self.initial_rgb = WSConv2d(img_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.rgb_layers.append(self.initial_rgb)
self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
self.final_block = nn.Sequential(
WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1, stride=1),
nn.LeakyReLU(0.2),
WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
nn.LeakyReLU(0.2),
WSConv2d(in_channels, 1, kernel_size=1, padding=0, stride=1),
)
def fade_in(self, alpha, downscaled, out):
return alpha * out + (1 - alpha) * downscaled
def minibatch_std(self, x: torch.Tensor):
batch_statistics = (
torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
)
return torch.cat([x, batch_statistics], dim=1)
def forward(self, x, alpha, steps):
cur_step = len(self.prog_blocks) - steps
out = self.leaky(self.rgb_layers[cur_step](x))
if steps == 0: # i.e, image is 4x4
out = self.minibatch_std(out)
return self.final_block(out).view(out.shape[0], -1)
downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
out = self.avg_pool(self.prog_blocks[cur_step](out))
out = self.fade_in(alpha, downscaled, out)
for step in range(cur_step + 1, len(self.prog_blocks)):
out = self.prog_blocks[step](out)
out = self.avg_pool(out)
out = self.minibatch_std(out)
return self.final_block(out).view(out.shape[0], -1)
In [8]:
gen = Generator(z_dim=z_dim, in_channels=in_channels)
disc = Discriminator(in_channels=in_channels)
for img_size in [4, 8, 16, 32, 64, 128, 256, 512]:
num_steps = int(log2(img_size / 4))
x = torch.randn((1, z_dim, 1, 1))
z = gen(x, 0.5, steps=num_steps)
assert z.shape == (1, 3, img_size, img_size)
out = disc(z, alpha=0.5, steps=num_steps)
print(f"Success! at img size: {img_size}")
C:\Users\palge\AppData\Local\Temp\ipykernel_14344\2363565070.py:31: UserWarning: std(): degrees of freedom is <= 0. Correction should be strictly less than the reduction factor (input numel divided by output numel). (Triggered internally at ..\aten\src\ATen\native\ReduceOps.cpp:1760.) torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
Success! at img size: 4 Success! at img size: 8 Success! at img size: 16 Success! at img size: 32 Success! at img size: 64 Success! at img size: 128 Success! at img size: 256 Success! at img size: 512
In [9]:
def get_loader(image_size):
transform = transforms.Compose([
transforms.Resize((image_size, image_size)), # Resize to a standard size (can be adjusted)
transforms.ToTensor(), # Convert the image to a PyTorch tensor
transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), # Normalize to [-1, 1]
])
batch_size = batch_sizes[int(log2(image_size / 4))]
dataset = dset.ImageFolder(root=dataroot, transform=transform)
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
num_workers=workers,
drop_last=True,
pin_memory=pin_memory,
persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor)
return dataloader, dataset
In [10]:
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
BATCH_SIZE, C, H, W = real.shape
beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
interpolated_images = real * beta + fake.detach() * (1 - beta)
interpolated_images.requires_grad_(True)
# Calculate critic scores
mixed_scores = critic(interpolated_images, alpha, train_step)
# Take the gradient of the scores with respect to the images
gradient = torch.autograd.grad(
inputs=interpolated_images,
outputs=mixed_scores,
grad_outputs=torch.ones_like(mixed_scores),
create_graph=True,
retain_graph=True,
)[0]
gradient = gradient.view(gradient.shape[0], -1)
gradient_norm = gradient.norm(2, dim=1)
gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
return gradient_penalty
In [11]:
from tqdm import tqdm
def train(disc, gen, loader, dataset, step, alpha, opt_disc, opt_gen, epoch, num_epochs):
loop = tqdm(loader, total=len(loader), miniters=50, desc=f"Epoch [{epoch + 1}/{num_epochs}]", leave=update_last)
for real, _ in loop:
#if batch_idx == 10:
# break
real = real.to(device)
cur_batch_size = real.shape[0]
# Generate noise and fake data
noise = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)
fake = gen(noise, alpha, step)
# Train Discriminator
disc_real = disc(real, alpha, step)
disc_fake = disc(fake.detach(), alpha, step)
gp = gradient_penalty(disc, real, fake, alpha, step, device)
loss_disc = (
-(torch.mean(disc_real) - torch.mean(disc_fake))
+ lambda_gp * gp
+ (0.001 * torch.mean(disc_real ** 2))
)
opt_disc.zero_grad()
loss_disc.backward()
opt_disc.step()
# Train Generator
gen_fake = disc(fake, alpha, step)
loss_gen = -torch.mean(gen_fake)
opt_gen.zero_grad()
loss_gen.backward()
opt_gen.step()
# Update alpha for progressive growing
alpha += cur_batch_size / ((prog_epochs[step] * 0.5) * len(dataset))
alpha = min(alpha, 1)
return alpha
In [12]:
import matplotlib.pyplot as plt
gen = Generator(z_dim=z_dim, in_channels=in_channels).to(device)
disc = Discriminator(in_channels=in_channels).to(device)
optimizer_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.0, 0.99))
optimizer_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.0, 0.99))
if use_pretrained:
gen.load_state_dict(checkpoint['gen_state'])
disc.load_state_dict(checkpoint['disc_state'])
optimizer_gen.load_state_dict(checkpoint['gen_optim'])
optimizer_disc.load_state_dict(checkpoint['disc_optim'])
fakes = []
if use_pretrained:
fakes = checkpoint["fakes"]
In [13]:
gen.train()
disc.train()
change_alpha = use_pretrained
for num_epochs in prog_epochs[step:]:
if start_train_at > image_size:
break
alpha = 1e-5
if change_alpha:
alpha = checkpoint['alpha']
change_alpha = False
loader, dataset = get_loader(4 * 2 ** step)
print(f"Current image size: {4 * 2 ** step}")
for epoch in range(start_epoch, num_epochs):
alpha = train(
disc,
gen,
loader,
dataset,
step,
alpha,
optimizer_disc,
optimizer_gen,
epoch,
num_epochs,
)
with torch.no_grad():
img = gen(fixed_noise, alpha, step) * 0.5 + 0.5
fakes.append(img)
if display_images:
img = img.cpu().detach()
img = (img - img.min()) / (img.max() - img.min())
# Display each image
fig, axes = plt.subplots(8, 8, figsize=(8, 8)) # Create a 8x8 grid
axes = axes.flatten()
for i, ax in enumerate(axes):
image = img[i].permute(1, 2, 0).numpy() # Rearrange dimensions to (H, W, C)
ax.imshow(image)
ax.axis('off')
plt.tight_layout()
plt.show()
if epoch + 1 != num_epochs:
torch.save({
'batch_sizes': batch_sizes,
'start_training_at': 4 * 2 ** step,
'alpha': alpha,
'fixed_noise': fixed_noise,
'z_dim': z_dim,
'in_channels': in_channels,
'epoch': epoch + 1,
'fakes': fakes,
'gen_state': gen.state_dict(),
'disc_state': disc.state_dict(),
'gen_optim': optimizer_gen.state_dict(),
'disc_optim': optimizer_disc.state_dict(),
}, f"models/training_imgsize_{4 * 2 ** (step)}_zdim_{z_dim}_progression.pth")
torch.save({
'batch_sizes': batch_sizes,
'start_training_at': (4 * 2 ** step) * 2,
'alpha': 1e-5,
'fixed_noise': fixed_noise,
'z_dim': z_dim,
'in_channels': in_channels,
'epoch': 0,
'fakes': fakes,
'gen_state': gen.state_dict(),
'disc_state': disc.state_dict(),
'gen_optim': optimizer_gen.state_dict(),
'disc_optim': optimizer_disc.state_dict(),
}, f"models/pretrained_imgsize_{4 * 2 ** step}_zdim_{z_dim}.pth")
step += 1
gen.eval()
disc.eval()
print("Eval mode activated")
Current image size: 64
Epoch [25/64]: 100%|██████████| 641/641 [08:46<00:00, 1.22it/s] Epoch [26/64]: 100%|██████████| 641/641 [08:00<00:00, 1.33it/s] Epoch [27/64]: 100%|██████████| 641/641 [07:12<00:00, 1.48it/s] Epoch [28/64]: 100%|██████████| 641/641 [08:01<00:00, 1.33it/s] Epoch [29/64]: 100%|██████████| 641/641 [08:00<00:00, 1.33it/s] Epoch [30/64]: 100%|██████████| 641/641 [07:28<00:00, 1.43it/s] Epoch [31/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [32/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [33/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [34/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [35/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [36/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [37/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [38/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [39/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [40/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [41/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [42/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [43/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [44/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [45/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [46/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [47/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [48/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [49/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [50/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [51/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [52/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [53/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [54/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [55/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [56/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [57/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [58/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [59/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [60/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [61/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [62/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [63/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s] Epoch [64/64]: 100%|██████████| 641/641 [07:11<00:00, 1.48it/s]
Current image size: 128
Epoch [25/64]: 100%|██████████| 1283/1283 [19:25<00:00, 1.10it/s] Epoch [26/64]: 100%|██████████| 1283/1283 [18:45<00:00, 1.14it/s] Epoch [27/64]: 100%|██████████| 1283/1283 [18:45<00:00, 1.14it/s] Epoch [28/64]: 100%|██████████| 1283/1283 [18:46<00:00, 1.14it/s] Epoch [29/64]: 100%|██████████| 1283/1283 [18:48<00:00, 1.14it/s] Epoch [30/64]: 100%|██████████| 1283/1283 [18:45<00:00, 1.14it/s] Epoch [31/64]: 100%|██████████| 1283/1283 [18:44<00:00, 1.14it/s] Epoch [32/64]: 100%|██████████| 1283/1283 [18:43<00:00, 1.14it/s] Epoch [33/64]: 100%|██████████| 1283/1283 [18:41<00:00, 1.14it/s] Epoch [34/64]: 100%|██████████| 1283/1283 [18:43<00:00, 1.14it/s] Epoch [35/64]: 100%|██████████| 1283/1283 [18:45<00:00, 1.14it/s] Epoch [36/64]: 100%|██████████| 1283/1283 [18:46<00:00, 1.14it/s] Epoch [37/64]: 100%|██████████| 1283/1283 [18:46<00:00, 1.14it/s] Epoch [38/64]: 100%|██████████| 1283/1283 [18:45<00:00, 1.14it/s] Epoch [39/64]: 100%|██████████| 1283/1283 [20:08<00:00, 1.06it/s] Epoch [40/64]: 100%|██████████| 1283/1283 [18:47<00:00, 1.14it/s] Epoch [41/64]: 100%|██████████| 1283/1283 [18:46<00:00, 1.14it/s] Epoch [42/64]: 100%|██████████| 1283/1283 [20:09<00:00, 1.06it/s] Epoch [43/64]: 100%|██████████| 1283/1283 [18:47<00:00, 1.14it/s] Epoch [44/64]: 100%|██████████| 1283/1283 [19:44<00:00, 1.08it/s] Epoch [45/64]: 100%|██████████| 1283/1283 [19:09<00:00, 1.12it/s] Epoch [46/64]: 100%|██████████| 1283/1283 [20:08<00:00, 1.06it/s] Epoch [47/64]: 100%|██████████| 1283/1283 [19:24<00:00, 1.10it/s] Epoch [48/64]: 100%|██████████| 1283/1283 [18:42<00:00, 1.14it/s] Epoch [49/64]: 100%|██████████| 1283/1283 [18:43<00:00, 1.14it/s] Epoch [50/64]: 100%|██████████| 1283/1283 [18:43<00:00, 1.14it/s] Epoch [51/64]: 100%|██████████| 1283/1283 [18:45<00:00, 1.14it/s] Epoch [52/64]: 100%|██████████| 1283/1283 [19:14<00:00, 1.11it/s] Epoch [53/64]: 100%|██████████| 1283/1283 [21:56<00:00, 1.03s/it] Epoch [54/64]: 100%|██████████| 1283/1283 [22:46<00:00, 1.07s/it] Epoch [55/64]: 100%|██████████| 1283/1283 [21:02<00:00, 1.02it/s] Epoch [56/64]: 100%|██████████| 1283/1283 [21:32<00:00, 1.01s/it] Epoch [57/64]: 100%|██████████| 1283/1283 [20:41<00:00, 1.03it/s] Epoch [58/64]: 100%|██████████| 1283/1283 [19:58<00:00, 1.07it/s] Epoch [59/64]: 100%|██████████| 1283/1283 [20:03<00:00, 1.07it/s] Epoch [60/64]: 100%|██████████| 1283/1283 [19:58<00:00, 1.07it/s] Epoch [61/64]: 100%|██████████| 1283/1283 [19:50<00:00, 1.08it/s] Epoch [62/64]: 100%|██████████| 1283/1283 [20:05<00:00, 1.06it/s] Epoch [63/64]: 100%|██████████| 1283/1283 [19:54<00:00, 1.07it/s] Epoch [64/64]: 100%|██████████| 1283/1283 [19:49<00:00, 1.08it/s]
Eval mode activated
In [65]:
img_list = []
for x in [x.detach().cpu() for x in fakes]: # To shorten number of images: [x[:8] for x in fakes]
x = torch.nn.functional.interpolate(x, size=(128, 128), mode="nearest")
img_list.append(vutils.make_grid(x, padding=2, normalize=True))
In [66]:
import matplotlib.animation as animation
from IPython.display import HTML
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
ani.save("anime.gif", writer='pillow', fps=10)
HTML(ani.to_jshtml())
Animation size has reached 21050564 bytes, exceeding the limit of 20971520.0. If you're sure you want a larger animation embedded, set the animation.embed_limit rc parameter to a larger value (in MB). This and further frames will be dropped.
Out[66]: